728x90

 

참고 코드:

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull


class KMeans(object):
    """K-means clustering (Lloyd's algorithm) with optional live plotting.

    The number of clusters ``k`` is fixed at construction time. ``fit``
    alternates between assigning every sample to its nearest centroid and
    moving each centroid to the mean of its assigned samples.
    """

    def __init__(self, k):
        """Store the number of clusters ``k``."""
        self.k = k

    def fit(self, X, initial_centroid_index=None, max_iters=10, seed=16, plt_process=False):
        """Cluster the rows of ``X`` and return ``(centroids, labels)``.

        Args:
            X: 2-D array with one sample per row.
            initial_centroid_index: Row indices of ``X`` used as the starting
                centroids. When ``None`` they are drawn at random.
            max_iters: Upper bound on the number of assign/update rounds.
            seed: Seed passed to ``np.random.seed`` before the random draw.
            plt_process: When true, draw every round with ``plot_converge``.

        Returns:
            A tuple ``(centroid, idx)``: the ``(k, n_features)`` centroid
            array and the per-sample cluster index (``None`` if
            ``max_iters`` is 0).

        Raises:
            ValueError: If the centroids must be drawn at random and ``k``
                exceeds the number of samples.
        """
        X = np.asarray(X)
        m = X.shape[0]

        # No explicit starting points were given: pick them at random.
        if initial_centroid_index is None:
            if self.k > m:
                raise ValueError(
                    "cannot pick %d distinct centroids from %d samples" % (self.k, m)
                )
            np.random.seed(seed)
            initial_centroid_index = np.random.randint(0, m, self.k)
            # randint samples with replacement; a repeated index would create
            # two identical centroids and therefore an empty cluster. Keep the
            # original draw when it is already distinct (so existing seeds
            # reproduce) and only re-draw without replacement otherwise.
            if np.unique(initial_centroid_index).size < self.k:
                initial_centroid_index = np.random.choice(m, self.k, replace=False)

        centroid = X[initial_centroid_index, :]
        idx = None

        # Only touch matplotlib's global state when plotting was requested.
        if plt_process:
            plt.ion()
        try:
            for _ in range(max_iters):
                # Assign every sample to its nearest centroid.
                new_idx = self.find_closest_centroids(X, centroid)

                # Unchanged labels mean the centroids are already the means of
                # their clusters, so further rounds cannot change the result.
                converged = idx is not None and np.array_equal(new_idx, idx)
                idx = new_idx
                if converged:
                    break

                if plt_process:
                    self.plot_converge(X, idx, initial_centroid_index)

                # Move every centroid to the mean of its cluster.
                centroid = self.compute_centroids(X, idx, centroid)
        finally:
            if plt_process:
                plt.ioff()

        if plt_process:
            plt.show()

        return centroid, idx

    def find_closest_centroids(self, X, centroid):
        """Return, for every row of ``X``, the index of the nearest centroid.

        Distances are squared Euclidean; ties go to the lowest index.
        """
        # Broadcast to (samples, centroids, features) and reduce the features.
        distance = np.sum((X[:, np.newaxis, :] - centroid) ** 2, axis=2)
        return distance.argmin(axis=1)

    def compute_centroids(self, X, idx, previous_centroids=None):
        """Return the mean of the samples assigned to each cluster.

        Args:
            X: 2-D array with one sample per row.
            idx: Cluster index of every sample.
            previous_centroids: Optional ``(k, n_features)`` array. A cluster
                with no samples keeps its row from this array; without it the
                row is left as NaN.

        Returns:
            A ``(k, n_features)`` float array of centroids.
        """
        centroids = np.full((self.k, X.shape[1]), np.nan)

        for i in range(self.k):
            members = X[idx == i]
            if members.shape[0] > 0:
                centroids[i, :] = members.mean(axis=0)
            elif previous_centroids is not None:
                # An empty cluster has no mean; a NaN centroid would win every
                # argmin in the next round, so keep the old position instead.
                centroids[i, :] = previous_centroids[i]

        return centroids

    def plot_converge(self, X, idx, initial_idx):
        """Draw the current clustering using the first two columns of ``X``.

        Shows all samples, highlights the initial centroids and outlines each
        cluster with its convex hull, then pauses briefly so the frame is
        visible.
        """
        plt.cla()  # clear the previous frame

        plt.title("k-meas converge process")
        plt.xlabel('density')
        plt.ylabel('sugar content')

        plt.scatter(X[:, 0], X[:, 1], c='lightcoral')
        # Mark the samples that served as initial centroids.
        plt.scatter(X[initial_idx, 0], X[initial_idx, 1], label='initial center', c='k')

        # Outline every cluster with its convex hull.
        for i in range(self.k):
            X_i = X[idx == i]

            # A 2-D hull needs at least three points; skip smaller clusters
            # instead of letting Qhull raise. (Degenerate clusters, e.g. all
            # points collinear, may still be rejected by Qhull.)
            if X_i.shape[0] < 3:
                continue

            hull = ConvexHull(X_i).vertices.tolist()
            hull.append(hull[0])  # close the polygon
            plt.plot(X_i[hull, 0], X_i[hull, 1], 'c--')

        plt.legend()
        plt.pause(0.5)


if __name__ == '__main__':
    # Load the comma-separated sample file and cluster it into three groups,
    # drawing every round of the algorithm.
    # Forward slashes avoid the invalid "\d" / "\w" escape sequences of the
    # original literal and work on every platform. The delimiter is a single
    # character because recent NumPy releases reject multi-character
    # delimiters; blanks after the commas are assumed to be tolerated by the
    # numeric parser.
    data = np.loadtxt('../data/watermelon4_0_Ch.txt', delimiter=',')
    centroid, idx = KMeans(3).fit(data, plt_process=True, seed=24)

 

출처: CSDN @Messor2020

 

+ Recent posts